vLLM Paged Attention
paged_attention_v1
- input
out: shape [num_seqs, num_heads, head_size]
query: shape [num_seqs, num_heads, head_size]
key_cache: shape [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: shape [num_blocks, num_heads, head_size, block_size]
block_tables: shape [num_seqs, max_num_blocks_per_seq]
num_kv_heads: num_heads
context_lens: num_seqs
x代表的是一个向量化的大小
- CUDA设置
gird: shape (num_heads, num_seqs, num_partition)
其中num_partition在不采用的时候为1block: shape (NUM_THREADS)
其中,对于attn_metadata,prefill的数据在前面,decode的数据在后面
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor, # token对应在table中的slot id
num_prefill_tokens=num_prefill_tokens, # prefill token的数目
num_decode_tokens=num_decode_tokens, # decode token的数目
seq_lens=seq_lens, # 各个句子的长度
seq_lens_tensor=seq_lens_tensor, # tensor类型的seq_lens,和上面没什么区别
max_query_len=max_query_len, # prefill阶段的query最大值,假如采用了chunk prefill,query_len,而不是context_len。比如484第一次chunked prefill算了20,则第二次max_query_len为464
max_prefill_seq_len=max_prefill_seq_len, # 可看上图
max_decode_seq_len=max_decode_seq_len, # Maximum sequence length among decode batch. 0 if there are prefill requests only.
query_start_loc=query_start_loc, # if the subquery length is [4, 6], it is [0, 4, 10]. 这个是query,假如decode,query是1
seq_start_loc=seq_start_loc, # if the sequence length is [4, 6], it is [0, 4, 10]. 这个是sequence
context_lens_tensor=context_lens_tensor, # context len的tensor,cache decode中不存储
block_tables=block_tables, # (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
use_cuda_graph=use_captured_graph,
测试中的数据
# 第一次 两个484传进去,prefill第一个484,第二个chunk28
num_prefills=2, num_prefill_tokens=512, num_decode_tokens=0, seq_lens=[28, 484], seq_lens_tensor=tensor([ 28, 484], device='cuda:0', dtype=torch.int32), max_query_len=484, max_prefill_seq_len=484, max_decode_seq_len=0, query_start_loc=tensor([ 0, 28, 512], device='cuda:0', dtype=torch.int32), seq_start_loc=tensor([ 0, 28, 512], device='cuda:0', dtype=torch.int32), context_lens_tensor=tensor([0, 0], device='cuda:0', dtype=torch.int32),
# 第二次,第一个推理485,第二个推理剩下的prefill
num_prefills=1, num_prefill_tokens=456, num_decode_tokens=1, seq_lens=[485, 484], seq_lens_tensor=tensor([485, 484], device='cuda:0', dtype=torch.int32), max_query_len=456, max_prefill_seq_len=484, max_decode_seq_len=485, query_start_loc=tensor([ 0, 1, 457], device='cuda:0', dtype=torch.int32), seq_start_loc=tensor([ 0, 485, 969], device='cuda:0', dtype=torch.int32), context_lens_tensor=tensor([484, 28], device='cuda:0', dtype=torch.int32),
# 第三次 两个decode
num_prefills=0, num_prefill_tokens=0, num_decode_tokens=2, slot_mapping=tensor([2051253, 2050756], device='cuda:0'), seq_lens=[486, 485], seq_lens_tensor=tensor([486, 485], device='cuda:0', dtype=torch.int32), max_query_len=1, max_prefill_seq_len=0, max_decode_seq_len=486, query_start_loc=tensor([0, 1, 2], device='cuda:0', dtype=torch.int32), seq_start_loc=tensor([ 0, 486, 971], device='cuda:0', dtype=torch.int32), context_lens_tensor=tensor([485, 484], device='cuda:0', dtype=torch.int32)
如果对应一个剩下的prefill,block table照旧,slot mapping代表需要传进去的数据。
forward:
# decode部分
decode_query: [query decode的token数量, num_heads, head_size]
key_cache: [总的num_blocks, block_size, num_heads, head_size]
value_cache: [总的num_blocks, block_size, num_heads, head_size]
ops.paged_attention_v1参数
print(query.shape)
print(key_cache.shape)
print(value_cache.shape)
print(num_kv_heads)
print(block_tables.shape)
print(seq_lens)
print(block_size)
print(max_seq_len)
30b
torch.Size([1, 56, 128])
torch.Size([195, 56, 16, 16, 8]) 56是num_heads 16(2)和8(4)相乘是head_sizes 16(3)是block_sizes
torch.Size([195, 56, 128, 16])
56
torch.Size([1, 13])
tensor([204], dtype=torch.int32)
16
204
125m
torch.Size([1, 12, 64])
torch.Size([7281, 12, 8, 16, 8])
torch.Size([7281, 12, 64, 16])
12
torch.Size([1, 13])
tensor([204], dtype=torch.int32)
16
204
attn_backend_impl的数据
print(num_heads)
print(head_size)
print(scale)
print(num_kv_heads)
print(alibi_slopes)
print(sliding_window)
print(kv_cache_dtype)
print(blocksparse_params)
opt-125m
12
64
0.125
12
None
None
auto
None
opt-30b
56
128
0.08838834764831845
56
None
None
auto
None
torch_sdpa make_metadata数据
prefill [['0', 195], ['1', 195], ['2', 195], ['3', 195], ['4', 195]]时参数:
is_prompt: True
seq_lens: [195, 195, 195, 195, 195]
seq_lens_tensor: None
max_decode_seq_len: None
num_prefills: 5
num_prefill_tokens: 975
num_decode_tokens: 0
block_tables: tensor([])
slot_mapping: tensor([15984, 15985, 15986, 15987, 15988, 15989, 15990, 15991, 15992, 15993,
15994, 15995, 15996, 15997, 15998, 15999, 15968, 15969, 15970, 15971,
15972, 15973, 15974, 15975, 15976, 15977, 15978, 15979, 15980, 15981,
15982, 15983, 15952, 15953, 15954, 15955, 15956, 15957, 15958, 15959,
15960, 15961, 15962, 15963, 15964, 15965, 15966, 15967, 15936, 15937,
15938, 15939, 15940, 15941, 15942, 15943, 15944, 15945, 15946, 15947,
15948, 15949, 15950, 15951, 15920, 15921, 15922, 15923, 15924, 15925,
15926, 15927, 15928, 15929, 15930, 15931, 15932, 15933, 15934, 15935,
15904, 15905, 15906, 15907, 15908, 15909, 15910, 15911, 15912, 15913,
15914, 15915, 15916, 15917, 15918, 15919, 15888, 15889, 15890, 15891,
15892, 15893, 15894, 15895, 15896, 15897, 15898, 15899, 15900, 15901,
15902, 15903, 15872, 15873, 15874, 15875, 15876, 15877, 15878, 15879,
15880, 15881, 15882, 15883, 15884, 15885, 15886, 15887, 15856, 15857,
15858, 15859, 15860, 15861, 15862, 15863, 15864, 15865, 15866, 15867,
15868, 15869, 15870, 15871, 15840, 15841, 15842, 15843, 15844, 15845,
15846, 15847, 15848, 15849, 15850, 15851, 15852, 15853, 15854, 15855,
15824, 15825, 15826, 15827, 15828, 15829, 15830, 15831, 15832, 15833,
15834, 15835, 15836, 15837, 15838, 15839, 15808, 15809, 15810, 15811,
15812, 15813, 15814, 15815, 15816, 15817, 15818, 15819, 15820, 15821,
15822, 15823, 15792, 15793, 15794, 15776, 15777, 15778, 15779, 15780,
15781, 15782, 15783, 15784, 15785, 15786, 15787, 15788, 15789, 15790,
15791, 15760, 15761, 15762, 15763, 15764, 15765, 15766, 15767, 15768,
15769, 15770, 15771, 15772, 15773, 15774, 15775, 15744, 15745, 15746,
15747, 15748, 15749, 15750, 15751, 15752, 15753, 15754, 15755, 15756,
15757, 15758, 15759, 15728, 15729, 15730, 15731, 15732, 15733, 15734,
15735, 15736, 15737, 15738, 15739, 15740, 15741, 15742, 15743, 15712,
15713, 15714, 15715, 15716, 15717, 15718, 15719, 15720, 15721, 15722,
15723, 15724, 15725, 15726, 15727, 15696, 15697, 15698, 15699, 15700,
15701, 15702, 15703, 15704, 15705, 15706, 15707, 15708, 15709, 15710,
15711, 15680, 15681, 15682, 15683, 15684, 15685, 15686, 15687, 15688,
15689, 15690, 15691, 15692, 15693, 15694, 15695, 15664, 15665, 15666,
15667, 15668, 15669, 15670, 15671, 15672, 15673, 15674, 15675, 15676,
15677, 15678, 15679, 15648, 15649, 15650, 15651, 15652, 15653, 15654,
15655, 15656, 15657, 15658, 15659, 15660, 15661, 15662, 15663, 15632,
15633, 15634, 15635, 15636, 15637, 15638, 15639, 15640, 15641, 15642,
15643, 15644, 15645, 15646, 15647, 15616, 15617, 15618, 15619, 15620,
15621, 15622, 15623, 15624, 15625, 15626, 15627, 15628, 15629, 15630,
15631, 15600, 15601, 15602, 15603, 15604, 15605, 15606, 15607, 15608,
15609, 15610, 15611, 15612, 15613, 15614, 15615, 15584, 15585, 15586,
15568, 15569, 15570, 15571, 15572, 15573, 15574, 15575, 15576, 15577,
15578, 15579, 15580, 15581, 15582, 15583, 15552, 15553, 15554, 15555,
15556, 15557, 15558, 15559, 15560, 15561, 15562, 15563, 15564, 15565,
15566, 15567, 15536, 15537, 15538, 15539, 15540, 15541, 15542, 15543,
15544, 15545, 15546, 15547, 15548, 15549, 15550, 15551, 15520, 15521,
15522, 15523, 15524, 15525, 15526, 15527, 15528, 15529, 15530, 15531,
15532, 15533, 15534, 15535, 15504, 15505, 15506, 15507, 15508, 15509,
15510, 15511, 15512, 15513, 15514, 15515, 15516, 15517, 15518, 15519,
15488, 15489, 15490, 15491, 15492, 15493, 15494, 15495, 15496, 15497,
15498, 15499, 15500, 15501, 15502, 15503, 15472, 15473, 15474, 15475,
15476, 15477, 15478, 15479, 15480, 15481, 15482, 15483, 15484, 15485,
15486, 15487, 15456, 15457, 15458, 15459, 15460, 15461, 15462, 15463,
15464, 15465, 15466, 15467, 15468, 15469, 15470, 15471, 15440, 15441,
15442, 15443, 15444, 15445, 15446, 15447, 15448, 15449, 15450, 15451,
15452, 15453, 15454, 15455, 15424, 15425, 15426, 15427, 15428, 15429,
15430, 15431, 15432, 15433, 15434, 15435, 15436, 15437, 15438, 15439,
15408, 15409, 15410, 15411, 15412, 15413, 15414, 15415, 15416, 15417,
15418, 15419, 15420, 15421, 15422, 15423, 15392, 15393, 15394, 15395,
15396, 15397, 15398, 15399, 15400, 15401, 15402, 15403, 15404, 15405,
15406, 15407, 15376, 15377, 15378, 15360, 15361, 15362, 15363, 15364,
15365, 15366, 15367, 15368, 15369, 15370, 15371, 15372, 15373, 15374,
15375, 15344, 15345, 15346, 15347, 15348, 15349, 15350, 15351, 15352,
15353, 15354, 15355, 15356, 15357, 15358, 15359, 15328, 15329, 15330,
15331, 15332, 15333, 15334, 15335, 15336, 15337, 15338, 15339, 15340,
15341, 15342, 15343, 15312, 15313, 15314, 15315, 15316, 15317, 15318,
15319, 15320, 15321, 15322, 15323, 15324, 15325, 15326, 15327, 15296,
15297, 15298, 15299, 15300, 15301, 15302, 15303, 15304, 15305, 15306,
15307, 15308, 15309, 15310, 15311, 15280, 15281, 15282, 15283, 15284,
15285, 15286, 15287, 15288, 15289, 15290, 15291, 15292, 15293, 15294,
15295, 15264, 15265, 15266, 15267, 15268, 15269, 15270, 15271, 15272,
15273, 15274, 15275, 15276, 15277, 15278, 15279, 15248, 15249, 15250,
15251, 15252, 15253, 15254, 15255, 15256, 15257, 15258, 15259, 15260,
15261, 15262, 15263, 15232, 15233, 15234, 15235, 15236, 15237, 15238,
15239, 15240, 15241, 15242, 15243, 15244, 15245, 15246, 15247, 15216,
15217, 15218, 15219, 15220, 15221, 15222, 15223, 15224, 15225, 15226,
15227, 15228, 15229, 15230, 15231, 15200, 15201, 15202, 15203, 15204,
15205, 15206, 15207, 15208, 15209, 15210, 15211, 15212, 15213, 15214,
15215, 15184, 15185, 15186, 15187, 15188, 15189, 15190, 15191, 15192,
15193, 15194, 15195, 15196, 15197, 15198, 15199, 15168, 15169, 15170,
15152, 15153, 15154, 15155, 15156, 15157, 15158, 15159, 15160, 15161,
15162, 15163, 15164, 15165, 15166, 15167, 15136, 15137, 15138, 15139,
15140, 15141, 15142, 15143, 15144, 15145, 15146, 15147, 15148, 15149,
15150, 15151, 15120, 15121, 15122, 15123, 15124, 15125, 15126, 15127,
15128, 15129, 15130, 15131, 15132, 15133, 15134, 15135, 15104, 15105,
15106, 15107, 15108, 15109, 15110, 15111, 15112, 15113, 15114, 15115,
15116, 15117, 15118, 15119, 15088, 15089, 15090, 15091, 15092, 15093,
15094, 15095, 15096, 15097, 15098, 15099, 15100, 15101, 15102, 15103,
15072, 15073, 15074, 15075, 15076, 15077, 15078, 15079, 15080, 15081,
15082, 15083, 15084, 15085, 15086, 15087, 15056, 15057, 15058, 15059,
15060, 15061, 15062, 15063, 15064, 15065, 15066, 15067, 15068, 15069,
15070, 15071, 15040, 15041, 15042, 15043, 15044, 15045, 15046, 15047,
15048, 15049, 15050, 15051, 15052, 15053, 15054, 15055, 15024, 15025,
15026, 15027, 15028, 15029, 15030, 15031, 15032, 15033, 15034, 15035,
15036, 15037, 15038, 15039, 15008, 15009, 15010, 15011, 15012, 15013,
15014, 15015, 15016, 15017, 15018, 15019, 15020, 15021, 15022, 15023,
14992, 14993, 14994, 14995, 14996, 14997, 14998, 14999, 15000, 15001,
15002, 15003, 15004, 15005, 15006, 15007, 14976, 14977, 14978, 14979,
14980, 14981, 14982, 14983, 14984, 14985, 14986, 14987, 14988, 14989,
14990, 14991, 14960, 14961, 14962])
decode [['0', 196], ['1', 196], ['2', 196], ['3', 196], ['4', 196]]时数据
is_prompt: False
slot_mapping: tensor([15795, 15587, 15379, 15171, 14963])
seq_lens: [196, 196, 196, 196, 196]
seq_lens_tensor: tensor([196, 196, 196, 196, 196], dtype=torch.int32)
max_decode_seq_len: 196
num_prefill_tokens: 0
num_decode_tokens: 5
num_prefills: 0
block_tables: tensor([[999, 998, 997, 996, 995, 994, 993, 992, 991, 990, 989, 988, 987],
[986, 985, 984, 983, 982, 981, 980, 979, 978, 977, 976, 975, 974],
[973, 972, 971, 970, 969, 968, 967, 966, 965, 964, 963, 962, 961],
[960, 959, 958, 957, 956, 955, 954, 953, 952, 951, 950, 949, 948],
[947, 946, 945, 944, 943, 942, 941, 940, 939, 938, 937, 936, 935]],
dtype=torch.int32)